from keras.models import Model
from keras.layers import Input, Dense
from keras.utils import plot_model

import matplotlib.pyplot as plt

# Fully connected classifier built with the Keras functional API:
# 784-dim input -> three ReLU hidden layers -> 10-way softmax output.

# Input is a flat 784-feature vector (presumably flattened 28x28 MNIST
# images, given the name -- confirm against the data pipeline).
mnist_input = Input(shape=(784,), name='input')
hidden1 = Dense(512, activation='relu', name='hidden1')(mnist_input)
# NOTE(review): 216 breaks the 512 -> ... -> 128 power-of-two pattern of the
# other hidden layers; confirm whether 256 was intended.
hidden2 = Dense(216, activation='relu', name='hidden2')(hidden1)
hidden3 = Dense(128, activation='relu', name='hidden3')(hidden2)
# Softmax over 10 units yields a probability distribution over 10 classes.
output = Dense(10, activation='softmax', name='output')(hidden3)

model = Model(inputs=mnist_input, outputs=output)

# Print the network structure.
# Model.summary() writes the summary to stdout itself and returns None, so
# its return value is not passed to print() (doing so emitted a stray "None").
model.summary()
